#=
Run this script under MPI, e.g.:
    /opt/openmpi/bin/mpiexec -n <nranks> julia <this file>
=#

# Calling OpenBLAS routines while running under OpenMPI frequently misbehaves —
# possibly because this OpenMPI build was compiled with ifort?
# Workaround: switch to Intel MPI, or use MKL instead.
# Under MPI, multithreaded BLAS badly degrades performance, so use
# Intel MPI + MKL with the SEQUENTIAL threading layer.
# The MKL import must come before everything else.
using MKL
MKL.set_threading_layer(MKL.THREADING_SEQUENTIAL)

include("../src/qmcmary.jl")
using ..qmcmary

using LinearAlgebra
using Random
using MPI
using Dates



"""
进行前向计算
"""
function forward(sp, tpidx, tpval, Nt, Nx, Ng, psi, the; hsf=missing)
    nx = sin.(psi).*sin.(the)
    ny = cos.(psi).*sin.(the)
    nz = cos.(the)
    comm = MPI.COMM_WORLD
    fname = "tmp$(MPI.Comm_rank(comm))_$(MPI.Comm_size(comm))"
    fil = open(fname, "w")
    close(fil)
    hscfg::Matrix{Int} = Matrix{Int}(undef, Nt, Nx)
    nwarm = 0
    if ismissing(hsf)
        for bi in Base.OneTo(Nt)
            cfg = 2(round.(rand(Nx)) .- 0.5)
            hscfg[bi, :] .= cfg
        end
        nwarm = 100
    else
        hscfg = hsf
    end
    sslen, bmats, allbmats, ss = initialize_SS(
        Nt, Ng, sp, hscfg, nx, ny, nz
    )
    println("rank: $(MPI.Comm_rank(comm))_$(MPI.Comm_size(comm)) tick1: $(Dates.now())")
    for _ in Base.OneTo(nwarm)
        hscfg, bmats, allbmats, ss, gf, ph = dqmc_step(
            Nt, Ng, sp, hscfg, nx, ny, nz, sslen, bmats, allbmats, ss
        )
    end
    println("rank: $(MPI.Comm_rank(comm))_$(MPI.Comm_size(comm)) tick2: $(Dates.now())")
    #
    tph = 0.0
    den = 0.0
    gprm = 0.0
    sdmu = 0.0
    nsp = 100
    tapemap = pbptp_calc_tapemap(sp, tpidx, tpval)
    for idx in Base.OneTo(nsp)
        hscfg, bmats, allbmats, ss, gf, ph = dqmc_step(
            Nt, Ng, sp, hscfg, nx, ny, nz, sslen, bmats, allbmats, ss
        )
        tph += ph
        tpbar = meas_gradtp(ss, allbmats, sp, hscfg, tpidx, tpval; tapemap=tapemap)
        for xi in Base.OneTo(Nx)
            den += ph*(2.0 - gf[xi, xi] - gf[xi+Nx, xi+Nx])/Nx
            #有来有回，有自旋上有自旋下
            gprm += ph*(-gf[tpidx[xi], xi]-gf[xi, tpidx[xi]]-gf[tpidx[xi]+Nx, xi+Nx]-gf[xi+Nx, tpidx[xi]+Nx])/Nx
            #这部分没有符号问题
            sdmu += tpbar[xi]/Nx
        end
    end
    println("rank: $(MPI.Comm_rank(comm))_$(MPI.Comm_size(comm)) tick3: $(Dates.now())")
    #
    tph = tph/nsp
    den = den/nsp/tph
    gprm = gprm/nsp/tph
    #这部分没有符号问题，不做reweight
    sdmu = sdmu/nsp
    res = -Nt*sp.dt*gprm + sdmu
    #println(-Nt*sp.dt*gprm)
    #println(sdmu)
    #println(res)
    return tph, den, gprm, sdmu, res, hscfg
end


"""
    sgn(L)

MPI driver for a DQMC run on an `L × L` square lattice: builds the kinetic
matrix and diagonal-bond tables, samples `nbin` bins per rank via `forward`
(the configuration is warm-started across bins), reduces all bins to rank 0
and prints mean ± error for the phase, density, hopping term, gradient term
and combined result.
"""
function sgn(L)
    MPI.Init()
    Nx = L^2
    tp = 0.25
    # Kinetic term: nearest-neighbour hopping -1 plus diagonal hopping tp.
    hk = lattice_tprim_square(ComplexF64, L, -1.0+0.0im, tp+0.0im)
    # Bond table: site A at (ux,uy) pairs with B at (ux+1,uy+1), wrapped
    # periodically via the project macro @↻.
    tpidx = zeros(Int, Nx)
    tpval = tp*ones(Nx)
    ucmap = zeros(Int, L, L)
    for ux in Base.OneTo(L), uy in Base.OneTo(L)
        ucmap[ux, uy] = L*(uy-1) + ux
    end
    for ux in Base.OneTo(L), uy in Base.OneTo(L)
        Aidx = ucmap[ux, uy]
        Bidx = ucmap[@↻(L, ux+1), @↻(L, uy+1)]
        tpidx[Aidx] = Bidx
        tpval[Aidx] = tp
    end
    # Interaction strength, Trotter slices, and (zero) spin angles.
    Ui = 1.0*ones(Nx)
    Nt = 200
    Ng = 4
    psi = zeros(Nx)
    the = zeros(Nx)
    comm = MPI.COMM_WORLD
    sp = default_splitting(Nt, hk, Ui; Z2=true)
    # Auxiliary-field configuration carried between bins (bin 1 starts cold).
    hsfig = missing
    nbin = 10
    nranks = MPI.Comm_size(comm)
    row = MPI.Comm_rank(comm) + 1
    sgnbin = zeros(ComplexF64, nranks, nbin)
    denbin = zeros(ComplexF64, nranks, nbin)
    gprmbin = zeros(ComplexF64, nranks, nbin)
    sdmubin = zeros(ComplexF64, nranks, nbin)
    resbin = zeros(ComplexF64, nranks, nbin)
    for binidx in Base.OneTo(nbin)
        # NOTE: local renamed from `sgn` to `phsgn` so it no longer shadows
        # this function's own name.
        phsgn, den, gprm, sdmu, res, hsfig = forward(
            sp, tpidx, tpval, Nt, Nx, Ng, psi, the; hsf=hsfig
        )
        sgnbin[row, binidx] = phsgn
        denbin[row, binidx] = den
        gprmbin[row, binidx] = gprm
        sdmubin[row, binidx] = sdmu
        resbin[row, binidx] = res
    end
    # Each rank fills only its own row (others stay zero), so a `+` reduction
    # assembles the full (rank × bin) table on root.
    sgnbinall = MPI.Reduce(sgnbin, +, 0, comm)
    denbinall = MPI.Reduce(denbin, +, 0, comm)
    gprmbinall = MPI.Reduce(gprmbin, +, 0, comm)
    sdmubinall = MPI.Reduce(sdmubin, +, 0, comm)
    resbinall = MPI.Reduce(resbin, +, 0, comm)
    if MPI.Comm_rank(comm) == 0
        totb = nbin * nranks
        # Mean and naive bin-error of one observable. FIX: use `abs2` (|z|^2)
        # so the variance is real and nonnegative for complex-valued bins —
        # the previous `(bins .- avg).^2` squared complex deviations without
        # taking the modulus and printed a complex "error bar".
        function report(label, bins)
            avg = sum(bins)/totb
            err = sqrt(sum(abs2.(bins .- avg))/(totb-1))
            println("$(label): $(avg) +- $(err)")
        end
        report("sgn", sgnbinall)
        report("den", denbinall)
        report("gpm", gprmbinall)
        report("sdu", sdmubinall)
        report("res", resbinall)
    end
end

# Entry point: run the simulation on a 6×6 lattice.
sgn(6)
